from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
from funasr import AutoModel
import numpy as np
import soundfile as sf
import argparse
import uvicorn
from urllib.parse import parse_qs
import os
import re
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from loguru import logger
import sys
import json
import traceback
import time
import torch
import torchaudio
import ChatTTS
import pandas as pd
import sounddevice as sd  # core playback dependency (local debugging)


logger.remove()
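# Route records below ERROR (loguru numbers ERROR as 40) to stdout, ERROR and above to stderr.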
log_format = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {file}:{line} - {message}"
logger.add(sys.stdout, format=log_format, level="DEBUG", filter=lambda record: record["level"].no < 40)
logger.add(sys.stderr, format=log_format, level="ERROR", filter=lambda record: record["level"].no >= 40)


class Config(BaseSettings):
    sv_thr: float = Field(0.3, description="Speaker verification threshold")
    chunk_size_ms: int = Field(300, description="Chunk size in milliseconds")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    bit_depth: int = Field(16, description="Bit depth")
    channels: int = Field(1, description="Number of audio channels")
    avg_logprob_thr: float = Field(-0.25, description="Average log-probability threshold")

config = Config()
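
# As a pydantic BaseSettings subclass, each Config field can also be overridden
# through an environment variable of the same name (matched case-insensitively),
# e.g. SV_THR=0.4.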

emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}
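
# Every language tag collapses to the single marker "<|lang|>", which
# format_str_v3 then uses as a split point between language segments.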

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}

def format_str(s):
    for sptk in emoji_dict:
        s = s.replace(sptk, emoji_dict[sptk])
    return s


def format_str_v2(s):
    sptk_dict = {}
    for sptk in emoji_dict:
        sptk_dict[sptk] = s.count(sptk)
        s = s.replace(sptk, "")
    emo = "<|NEUTRAL|>"
    for e in emo_dict:
        if sptk_dict[e] > sptk_dict[emo]:
            emo = e
    for e in event_dict:
        if sptk_dict[e] > 0:
            s = event_dict[e] + s
    s = s + emo_dict[emo]

    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji)
        s = s.replace(emoji + " ", emoji)
    return s.strip()
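
# format_str_v3 merges the per-language segments produced above: it drops a
# repeated leading event emoji and a repeated trailing emotion emoji when
# consecutive segments share them.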

def format_str_v3(s):
    def get_emo(s):
        return s[-1] if s[-1] in emo_set else None

    def get_event(s):
        return s[0] if s[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
            continue
        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) is not None:
            s_list[i] = s_list[i][1:]
        cur_ent_event = get_event(s_list[i])
        if get_emo(s_list[i]) is not None and get_emo(s_list[i]) == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += s_list[i].strip()
    new_s = new_s.replace("The.", " ")
    return new_s.strip()

def contains_chinese_english_number(s: str) -> bool:
    # Check if the string contains any Chinese character, English letter, or Arabic number
    return bool(re.search(r'[\u4e00-\u9fffA-Za-z0-9]', s))
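
# e.g. contains_chinese_english_number("你好") -> True,
#      contains_chinese_english_number("...") -> False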


sv_pipeline = pipeline(
    task='speaker-verification',
    model='iic/speech_eres2net_large_sv_zh-cn_3dspeaker_16k',
    model_revision='v1.0.0'
)

asr_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='iic/SenseVoiceSmall',
    model_revision="master",
    device="cpu",
    disable_update=True
)
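
# NOTE: asr_pipeline is kept for reference but currently unused; recognition
# goes through model_asr below (see the commented-out call inside asr()).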

model_asr = AutoModel(
    model="iic/SenseVoiceSmall",
    trust_remote_code=True,
    remote_code="./model.py",    
    device="cpu",
    disable_update=True
)

model_vad = AutoModel(
    model="fsmn-vad",
    model_revision="v2.0.4",
    disable_pbar = True,
    max_end_silence_time=500,
    # speech_noise_thres=0.6,
    disable_update=True,
)
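
# In streaming mode fsmn-vad emits partial segments: [beg, -1] when speech
# starts and [-1, end] when it ends (or [beg, end] for a segment contained in
# one chunk). The WebSocket loop below tracks last_vad_beg / last_vad_end to
# stitch these together.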

reg_spks_files = [
    "speaker/speaker1_a_cn_16k.wav"
]
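
# Speaker-verification enrollment audio; the file stem (e.g. "speaker1_a_cn_16k")
# becomes the speaker name. Files are expected to be 16 kHz mono WAV to match
# the sv model above.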

def reg_spk_init(files):
    reg_spk = {}
    for f in files:
        data, sr = sf.read(f, dtype="float32")
        k, _ = os.path.splitext(os.path.basename(f))
        reg_spk[k] = {
            "data": data,
            "sr":   sr,
        }
    return reg_spk

reg_spks = reg_spk_init(reg_spks_files)

def speaker_verify(audio, sv_thr):
    hit = False
    speaker = ""
    for k, v in reg_spks.items():
        # pass the threshold as the `thr` keyword expected by the 3D-Speaker pipeline
        res_sv = sv_pipeline([audio, v["data"]], thr=sv_thr)
        if res_sv["score"] >= sv_thr:
            hit = True
            speaker = k  # remember the matched speaker rather than the last key seen
        logger.info(f"[speaker_verify] audio_len: {len(audio)}; sv_thr: {sv_thr}; hit: {hit}; {k}: {res_sv}")
    return hit, speaker


def asr(audio, lang, cache, use_itn=False):
    # with open('test.pcm', 'ab') as f:
    #     logger.debug(f'write {f.write(audio)} bytes to `test.pcm`')
    # result = asr_pipeline(audio, lang)
    start_time = time.time()
    result = model_asr.generate(
        input           = audio,
        cache           = cache,
        language        = lang.strip(),
        use_itn         = use_itn,
        batch_size_s    = 60,
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    logger.debug(f"asr elapsed: {elapsed_time * 1000:.2f} milliseconds")
    return result
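
# NOTE: model_asr.generate returns a list of result dicts; result[0]["text"]
# holds the transcript with inline tags such as "<|zh|><|NEUTRAL|>..." that
# format_str_v3 later converts to emoji.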

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    logger.error("Exception occurred", exc_info=True)
    if isinstance(exc, HTTPException):
        status_code = exc.status_code
        message = exc.detail
        data = ""
    elif isinstance(exc, RequestValidationError):
        status_code = HTTP_422_UNPROCESSABLE_ENTITY
        message = "Validation error: " + str(exc.errors())
        data = ""
    else:
        status_code = 500
        message = "Internal server error: " + str(exc)
        data = ""

    return JSONResponse(
        status_code=status_code,
        content=TranscriptionResponse(
            code=status_code,
            info=message,  # the response model's field is `info`, not `msg`
            data=data
        ).model_dump()
    )

# Define the response model
class TranscriptionResponse(BaseModel):
    code: int
    info: str
    data: str
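
# Status codes used over the WebSocket below:
#   2  -> ASR result / "detect speaker" / "detect speech" notifications
#   1  -> TTS inference started
#   0  -> synthesized audio about to be sent (followed by one binary PCM frame)
#   3  -> audio send complete
#   -1 -> TTS inference error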


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    try:
        query_params = parse_qs(websocket.scope['query_string'].decode())
        sv = query_params.get('sv', ['false'])[0].lower() in ['true', '1', 't', 'y', 'yes']
        lang = query_params.get('lang', ['auto'])[0].lower()
        
        await websocket.accept()
        chunk_size = int(config.chunk_size_ms * config.sample_rate / 1000)
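        # i.e. 300 ms * 16000 Hz / 1000 = 4800 samples per chunk (with the defaults)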
        audio_buffer = np.array([], dtype=np.float32)
        audio_vad = np.array([], dtype=np.float32)

        cache = {}
        cache_asr = {}
        last_vad_beg = last_vad_end = -1
        offset = 0
        hit = False
        
        buffer = b""
        target_sample_rate = 16000  # fixed 16 kHz output sample rate
        
        while True:
            data = await websocket.receive_bytes()
            buffer += data
            if len(buffer) < 2:
                continue
                
            # Convert all complete int16 samples to float32 in [-1, 1]; keep any odd
            # trailing byte in `buffer` until its second half arrives.
            audio_buffer = np.append(
                audio_buffer,
                np.frombuffer(buffer[:len(buffer) - (len(buffer) % 2)], dtype=np.int16).astype(np.float32) / 32767.0
            )
            buffer = buffer[len(buffer) - (len(buffer) % 2):]
   
            while len(audio_buffer) >= chunk_size:
                chunk = audio_buffer[:chunk_size]
                audio_buffer = audio_buffer[chunk_size:]
                audio_vad = np.append(audio_vad, chunk)
                
                if last_vad_beg > 1:  # a speech onset has already been detected (VAD timestamps are in ms)
                    if sv:
                        if not hit:
                            hit, speaker = speaker_verify(audio_vad[int((last_vad_beg - offset) * config.sample_rate / 1000):], config.sv_thr)
                            if hit:
                                response = TranscriptionResponse(
                                    code=2,
                                    info="detect speaker",
                                    data=speaker
                                )
                                await websocket.send_json(response.model_dump())
                    else:
                        response = TranscriptionResponse(
                            code=2,
                            info="detect speech",
                            data=''
                        )
                        await websocket.send_json(response.model_dump())

                res = model_vad.generate(input=chunk, cache=cache, is_final=False, chunk_size=config.chunk_size_ms)
                if len(res[0]["value"]):
                    vad_segments = res[0]["value"]
                    for segment in vad_segments:
                        if segment[0] > -1:
                            last_vad_beg = segment[0]                           
                        if segment[1] > -1:
                            last_vad_end = segment[1]
                        if last_vad_beg > -1 and last_vad_end > -1:
                            last_vad_beg -= offset
                            last_vad_end -= offset
                            offset += last_vad_end
                            beg = int(last_vad_beg * config.sample_rate / 1000)
                            end = int(last_vad_end * config.sample_rate / 1000)
                            logger.info(f"[vad segment] audio_len: {end - beg}")
                            result = None if sv and not hit else asr(audio_vad[beg:end], lang.strip(), cache_asr, True)
                            logger.info(f"asr response: {result}")
                            audio_vad = audio_vad[end:]
                            last_vad_beg = last_vad_end = -1
                            hit = False
                            
                            # ---- Reply flow (per requirements): code=1 -> inference -> code=0 -> send audio ----
                            try:
                                # 1. Pick the TTS text: use the ASR transcript when present;
                                #    otherwise fall back to the default greeting (adjust here
                                #    if inference isn't wanted for empty results).
                                need_tts = True
                                tts_text = "你好啊，我是亥时弈输灯花。~"  # default reply text (Chinese input for the TTS model)
                                if result and result[0]['text'].strip():
                                    # keep only the text after the final "<|...|>" tag
                                    tts_text = result[0]['text'].strip().split('>')[-1] + '~'

                                # 2. Notify the client that inference is starting (code=1).
                                if need_tts:
                                    await websocket.send_json({
                                        "code": 1,
                                        "info": "tts_inference_start",
                                        "msg": "generating spoken reply"
                                    })
                                    logger.info("sent code=1: ChatTTS inference started")

                                    # 3. Run ChatTTS inference.
                                    wavs = chat.infer([tts_text], params_infer_code=params_infer_code)
                                    audio_data = wavs[0]  # raw 24 kHz float waveform

                                    # 4. Resample 24 kHz -> 16 kHz.
                                    audio_tensor = torch.tensor(audio_data).unsqueeze(0)
                                    resampler = torchaudio.transforms.Resample(orig_freq=24000, new_freq=target_sample_rate)
                                    resampled_tensor = resampler(audio_tensor)
                                    resampled_audio = resampled_tensor.squeeze(0).numpy()

                                    # 5. Convert to 16-bit mono PCM bytes.
                                    pcm_data = np.int16(resampled_audio * 32767)
                                    binary_audio = pcm_data.tobytes()

                                    # 6. Announce the audio before sending it (code=0).
                                    await websocket.send_json({
                                        "code": 0,
                                        "info": "audio_ready_to_send",
                                        "audio_meta": {
                                            "sample_rate": target_sample_rate,
                                            "bits_per_sample": 16,
                                            "channels": 1,
                                            "audio_size": len(binary_audio)
                                        }
                                    })
                                    logger.info("sent code=0: audio ready, sending next")

                                    # 7. Send the binary PCM audio, then confirm completion (code=3).
                                    await websocket.send_bytes(binary_audio)
                                    logger.info(f"sent 16 kHz binary audio, size: {len(binary_audio)} bytes")
                                    await websocket.send_json({
                                        "code": 3,
                                        "info": "audio_sent",
                                        "audio_meta": {
                                            "sample_rate": target_sample_rate,
                                            "bits_per_sample": 16,
                                            "channels": 1,
                                            "audio_size": len(binary_audio)
                                        }
                                    })

                                    # 8. Local playback (debugging only).
                                    print("playing audio...")
                                    sd.play(resampled_audio, samplerate=target_sample_rate)
                                    sd.wait()
                                    print("playback finished!")
                                # 9. Nothing to synthesize: send no reply (per requirements).
                                else:
                                    logger.info("no ChatTTS inference needed; nothing sent")

                            except Exception as tts_err:
                                logger.error(f"ChatTTS processing failed: {str(tts_err)}")
                                # Optionally notify the client that inference failed.
                                await websocket.send_json({
                                    "code": -1,
                                    "info": "tts_inference_error",
                                    "error_msg": str(tts_err)[:100]  # truncate the error message
                                })

                            # 10. Send the ASR result (original logic retained).
                            if result is not None:
                                response = TranscriptionResponse(
                                    code=2,  # 2 avoids clashing with code=0/1, which are reused for TTS
                                    info=json.dumps(result[0], ensure_ascii=False),
                                    data=format_str_v3(result[0]['text'])
                                )
                                await websocket.send_json(response.model_dump())
                                
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {e}\nCall stack:\n{traceback.format_exc()}")
        await websocket.close()
    finally:
        audio_buffer = np.array([], dtype=np.float32)
        audio_vad = np.array([], dtype=np.float32)
        cache.clear()
        logger.info("Cleaned up resources after WebSocket disconnect")
if __name__ == "__main__":

    # -------------------------- 1. Basic configuration and model initialization --------------------------
    # Fix the random seeds so results are reproducible.
    torch.manual_seed(100)  # PyTorch seed
    np.random.seed(120)     # NumPy seed

    # PyTorch optimization settings (keeps the model running efficiently).
    torch._dynamo.config.cache_size_limit = 64
    torch._dynamo.config.suppress_errors = True
    torch.set_float32_matmul_precision('high')


    # Load the speaker embedding (preserves the enrolled speaker's voice style).
    data = pd.read_csv("./ChatTTS/speaker/1145.csv", header=None)
    rand_spk = torch.tensor(data.iloc[0], dtype=torch.float32)

    # Initialize and load the ChatTTS model (make sure the local path is correct).
    chat = ChatTTS.Chat()
    chat.load(source='local', custom_path='./ChatTTS/model', compile=False)
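
    # chat.infer returns float waveforms sampled at 24 kHz; the WebSocket handler
    # resamples them to 16 kHz before streaming them out.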

    # -------------------------- 2. Audio generation parameters (conservative and stable) --------------------------
    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb=rand_spk,       # fixed speaker embedding
        top_P=0.7,              # strict nucleus sampling for smoother speech
        top_K=20,               # cap the candidate pool to reduce stutter
        repetition_penalty=1.05 # lightly penalize repeated content
    )
    # params_refine_text = {'prompt': '[break_2]'}  # text refinement for better phrasing


    parser = argparse.ArgumentParser(description="Run the FastAPI app with a specified port.")
    parser.add_argument('--port', type=int, default=27000, help='Port number to run the FastAPI app on.')
    # parser.add_argument('--certfile', type=str, default='path_to_your_SSL_certificate_file.crt', help='SSL certificate file')
    # parser.add_argument('--keyfile', type=str, default='path_to_your_SSL_certificate_file.key', help='SSL key file')
    args = parser.parse_args()
    # uvicorn.run(app, host="0.0.0.0", port=args.port, ssl_certfile=args.certfile, ssl_keyfile=args.keyfile)
    uvicorn.run(app, host="0.0.0.0", port=args.port)
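
# Minimal client sketch for manual testing (hypothetical; assumes the third-party
# `websockets` package and a raw 16 kHz / 16-bit / mono PCM test file):
#
#   import asyncio, websockets
#
#   async def main():
#       uri = "ws://localhost:27000/ws/transcribe?sv=false&lang=auto"
#       async with websockets.connect(uri) as ws:
#           with open("sample_16k_mono.pcm", "rb") as f:   # hypothetical test file
#               while chunk := f.read(9600):               # 300 ms of int16 audio
#                   await ws.send(chunk)
#                   await asyncio.sleep(0.3)
#           while True:
#               msg = await ws.recv()
#               print(msg if isinstance(msg, str) else f"<{len(msg)} audio bytes>")
#
#   asyncio.run(main())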
